import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# The starred imports are expected to provide MyDataset, creat_lang, read_data,
# MAX_LEN, Encoder, and AttentionDencoder (names kept as spelled in the repo).
from mydataset import *
from precessing import *
from Seq2SeqModel import *

sos_token = 0    # index of the start-of-sequence token used to seed the decoder
batch_size = 1

# Build the vocabularies and load the parallel training data.
input_lang, out_lang = creat_lang(500)
input_data, tag_input, tag_output = read_data(input_lang, out_lang, data_path="data/seq.data")

print("数据长度：",len(input_data))

# Pick a device up front; the models and every batch are moved there explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = MyDataset(input_data, tag_input, tag_output)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

EncoderModel = Encoder(input_lang.n_words, hidden_size=32).to(device)
DecoderModel = AttentionDencoder(output_size=out_lang.n_words, hidden_size=32).to(device)

crossentropyloss = nn.CrossEntropyLoss()
opt_config = [{'params': EncoderModel.parameters(), 'lr': 1e-4},
              {'params': DecoderModel.parameters(), 'lr': 1e-4}]
opt = torch.optim.Adam(opt_config, lr=1e-4)
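# Note: both parameter groups use the same lr, so the per-group settings are
# redundant here; they just make it easy to tune each sub-model's rate later.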
for epoch in range(1):
    for data in train_loader:
        # Each batch: source sequence, decoder input, and decoder target.
        input_data, tag_input, tag_output = [t.to(device) for t in data]

        # Encode the full source sequence; the initial hidden state is None.
        encoder_output, hidden = EncoderModel(input_data, None)
        # Seed the decoder with the SOS token for every sequence in the batch.
        decoder_input = torch.tensor([sos_token] * input_data.shape[0], device=device)

        decoder_outputs = []
        for i in range(MAX_LEN):
            # One decoding step: attend over the encoder outputs and emit a
            # distribution over the output vocabulary.
            output, hidden, attn_weights = DecoderModel(decoder_input, hidden, encoder_output)
            decoder_outputs.append(output)

            # Greedy alternative, used when teacher forcing is off:
            # _, topi = output.topk(1)
            # decoder_input = topi.view(-1)
            decoder_input = tag_output[:, i]  # teacher forcing: feed the ground-truth token
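
            # Variant (sketch, not in the original): mix teacher forcing with the
            # model's own greedy prediction, e.g. with probability 0.5 per step
            # (requires `import random`):
            # if random.random() < 0.5:
            #     decoder_input = tag_output[:, i]
            # else:
            #     _, topi = output.topk(1)
            #     decoder_input = topi.view(-1).detach()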

        # Sum the cross-entropy loss over all decoded time steps.
        loss = 0
        for step, out in enumerate(decoder_outputs):
            # out has shape (batch, 1, vocab); drop the length-1 dim for the loss.
            loss += crossentropyloss(out[:, 0, :], tag_output[:, step])
        print(f"epoch {epoch} loss {loss.item():.4f}")

        opt.zero_grad()
        loss.backward()
        opt.step()
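
        # Optional safeguard (sketch, not in the original): clip gradients
        # between loss.backward() and opt.step() to stabilize RNN training, e.g.
        # torch.nn.utils.clip_grad_norm_(EncoderModel.parameters(), max_norm=1.0)
        # torch.nn.utils.clip_grad_norm_(DecoderModel.parameters(), max_norm=1.0)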

# Ensure the checkpoint directory exists (path kept as spelled in the repo).
os.makedirs("savemode", exist_ok=True)
torch.save(EncoderModel.state_dict(), "savemode/EncoderModel.pkl")
torch.save(DecoderModel.state_dict(), "savemode/DecoderModel.pkl")
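
# Minimal reload sketch (assumes the same vocab sizes and hidden_size as above):
# enc = Encoder(input_lang.n_words, hidden_size=32).to(device)
# dec = AttentionDencoder(output_size=out_lang.n_words, hidden_size=32).to(device)
# enc.load_state_dict(torch.load("savemode/EncoderModel.pkl", map_location=device))
# dec.load_state_dict(torch.load("savemode/DecoderModel.pkl", map_location=device))
# enc.eval(); dec.eval()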